Abstract

This analysis develops algorithms that give medical practitioners additional intelligence about a patient's heart condition by predicting whether a patient will develop heart disease or not (given a certain set of conditions). We carried out an in-depth analysis of the importance of each predictor for the variable of interest (whether the patient has heart disease or not). We tried four different algorithms and built an ensemble of all four. The resulting models gave fairly stable predictions and performed significantly better than the benchmark (the proportion of the majority class). Doctors can use the results of this exercise to augment their decisions about a patient's health condition.


Introduction

In this analysis we aim to build an algorithm and identify features that help predict whether a patient is likely to develop heart disease. The data used for this analysis was accessed through the UCI Machine Learning Repository. The parent data set contains 75 attributes (columns) in total; however, this analysis uses only a subset of them. A detailed description of the data set is provided in the appendix, and the analysis below explains each column in detail.


Methods

Data

Let’s check the size of the data set first.

print(paste("Number of rows in the dataset:",nrow(hd)))
## [1] "Number of rows in the dataset: 920"
print(paste("Number of columns in the dataset:",ncol(hd)))
## [1] "Number of columns in the dataset: 15"

Each row in the data set represents one patient, with different features recorded during the test. The names and types of the columns in the data set are:

print(colnames(hd))
##  [1] "age"      "sex"      "cp"       "trestbps" "chol"     "fbs"     
##  [7] "restecg"  "thalach"  "exang"    "oldpeak"  "slope"    "ca"      
## [13] "thal"     "num"      "location"
print(str(hd))
## tibble [920 x 15] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ age     : num [1:920] 63 67 67 37 41 56 62 57 63 53 ...
##  $ sex     : num [1:920] 1 1 1 1 0 1 0 0 1 1 ...
##  $ cp      : num [1:920] 1 4 4 3 2 2 4 4 4 4 ...
##  $ trestbps: num [1:920] 145 160 120 130 130 120 140 120 130 140 ...
##  $ chol    : num [1:920] 233 286 229 250 204 236 268 354 254 203 ...
##  $ fbs     : num [1:920] 1 0 0 0 0 0 0 0 0 1 ...
##  $ restecg : num [1:920] 2 2 2 0 2 0 2 0 2 2 ...
##  $ thalach : num [1:920] 150 108 129 187 172 178 160 163 147 155 ...
##  $ exang   : num [1:920] 0 1 1 0 0 0 0 1 0 1 ...
##  $ oldpeak : num [1:920] 2.3 1.5 2.6 3.5 1.4 0.8 3.6 0.6 1.4 3.1 ...
##  $ slope   : num [1:920] 3 2 2 3 1 1 3 1 2 3 ...
##  $ ca      : num [1:920] 0 3 2 0 0 0 2 0 1 0 ...
##  $ thal    : num [1:920] 6 3 7 3 3 3 3 3 7 7 ...
##  $ num     : chr [1:920] "v0" "v2" "v1" "v0" ...
##  $ location: chr [1:920] "cl" "cl" "cl" "cl" ...
##  - attr(*, "spec")=
##   .. cols(
##   ..   age = col_double(),
##   ..   sex = col_double(),
##   ..   cp = col_double(),
##   ..   trestbps = col_double(),
##   ..   chol = col_double(),
##   ..   fbs = col_double(),
##   ..   restecg = col_double(),
##   ..   thalach = col_double(),
##   ..   exang = col_double(),
##   ..   oldpeak = col_double(),
##   ..   slope = col_double(),
##   ..   ca = col_double(),
##   ..   thal = col_double(),
##   ..   num = col_character(),
##   ..   location = col_character()
##   .. )
## NULL

Each column in the data set represents a different feature recorded across patients for this experiment. A detailed description of each column can be found in the Appendix section.

The purpose of this analysis is to predict, given a certain set of conditions, whether a patient has heart disease or not. Our data set has a column called num which contains the following values:

unique(hd$num)
## [1] "v0" "v2" "v1" "v3" "v4"

The meaning of these values is listed below:

  1. v0 : 0 major vessels with greater than 50% diameter narrowing. No presence of heart disease.
  2. v1 : 1 major vessel with greater than 50% diameter narrowing
  3. v2 : 2 major vessels with greater than 50% diameter narrowing
  4. v3 : 3 major vessels with greater than 50% diameter narrowing
  5. v4 : 3 major vessels with greater than 50% diameter narrowing

For the purpose of this analysis, if a row (representing a patient) has the value v0 in the num column, we consider it an indicator of no heart disease. If the row has any other value, i.e. v1, v2, v3 or v4, we consider it an indicator of heart disease. Before we jump into data exploration, let’s convert the num column into a binary class variable.

# Creating a binary flag for heart disease from the num column and then converting it into a factor
hd = hd %>% mutate(hd_ind = case_when(
                   num == "v0" ~ 0,
                   num == "v1" ~ 1,
                   num == "v2" ~ 1,
                   num == "v3" ~ 1,
                   num == "v4" ~ 1)) %>% mutate(hd_ind = as.factor(hd_ind)) %>% select(-num)

The code chunk above creates a binary hd_ind flag, converts it into a factor and drops the original num column. Going forward we will use this flag to predict whether a patient (given a certain set of conditions) has heart disease or not. The first 5 rows of the data show how it looks after this pre-processing:

print(head(as_tibble(hd),5), width = Inf)
## # A tibble: 5 x 15
##     age   sex    cp trestbps  chol   fbs restecg thalach exang oldpeak slope
##   <dbl> <dbl> <dbl>    <dbl> <dbl> <dbl>   <dbl>   <dbl> <dbl>   <dbl> <dbl>
## 1    63     1     1      145   233     1       2     150     0     2.3     3
## 2    67     1     4      160   286     0       2     108     1     1.5     2
## 3    67     1     4      120   229     0       2     129     1     2.6     2
## 4    37     1     3      130   250     0       0     187     0     3.5     3
## 5    41     0     2      130   204     0       2     172     0     1.4     1
##      ca  thal location hd_ind
##   <dbl> <dbl> <chr>    <fct> 
## 1     0     6 cl       0     
## 2     3     3 cl       1     
## 3     2     7 cl       1     
## 4     0     3 cl       0     
## 5     0     3 cl       0

Data Exploration

Now let’s begin with the data exploration. First, let’s check the proportion of missing values in every column of the data frame.

na_prop = function(x) {
  mean(is.na(x))
}

print(sapply(hd, na_prop))
##         age         sex          cp    trestbps        chol         fbs 
## 0.000000000 0.000000000 0.000000000 0.064130435 0.032608696 0.097826087 
##     restecg     thalach       exang     oldpeak       slope          ca 
## 0.002173913 0.059782609 0.059782609 0.067391304 0.335869565 0.664130435 
##        thal    location      hd_ind 
## 0.528260870 0.000000000 0.000000000

All the columns are fairly well populated except slope, ca and thal, each of which is missing in more than a third of the rows. We drop these columns from our analysis because they are too sparsely populated to contribute to our model in a meaningful way. After that we remove all rows that have any missing values. Once these two data cleaning steps are done we print the dimensions of the resulting data set.

# Create a data set without columns containing more than 33% NAs, then drop rows with any remaining NA
hd_clean = na.omit(hd[, sapply(hd, na_prop) <= 0.33])
print(paste("Number of rows in the dataset:",nrow(hd_clean)))
## [1] "Number of rows in the dataset: 740"
print(paste("Number of columns in the dataset:",ncol(hd_clean)))
## [1] "Number of columns in the dataset: 12"
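
The two cleaning steps above can be sketched on a small toy data frame. This is only an illustration under assumed values: the data frame toy and the names na_share, keep_cols, toy_cols and toy_clean are hypothetical and not part of the analysis.

```r
# Toy data frame (hypothetical values) with one sparsely populated column
toy = data.frame(
  age  = c(63, 67, NA, 37, 41),
  chol = c(233, 286, 229, 250, 204),
  ca   = c(0, NA, NA, NA, NA)
)

# Share of missing values per column
na_share = sapply(toy, function(x) mean(is.na(x)))

# Step 1: keep only columns with at most 33% missing values
keep_cols = names(na_share)[na_share <= 0.33]
toy_cols = toy[, keep_cols, drop = FALSE]

# Step 2: drop the rows that still contain a missing value
toy_clean = na.omit(toy_cols)

print(na_share)
print(dim(toy_clean))
```

The order of the steps matters. Calling na.omit on the full toy frame first would keep only one row, because the sparse ca column is missing almost everywhere.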

We check for class imbalance in the data set with the frequency plot below. Please keep the following definitions in mind while reading the chart:

  • Value 0 : No Disease
  • Value 1 : Disease
# Summarizing the data
summ = hd_clean %>% group_by(hd_ind) %>% summarise(n = n())

# Making the plot
fig = plot_ly(summ, 
              x = ~hd_ind, 
              y = ~n,
              color = ~hd_ind,
              colors = c("#34eb89","#eb4034"),
              text = summ$hd_ind,
              textposition = 'auto',
              type = "bar", 
              name = c("No Disease", "Disease"))

# Changing the layout
fig = fig %>% layout(yaxis = list(title = 'Number of Patients'), 
                     xaxis = list(title = 'Disease'),
                     title = "Class Frequency Plot")%>% 
  layout(showlegend = TRUE)

# Displaying the figure
fig
  • We can clearly see that both classes have a similar presence in the data set. Because of this, accuracy is a reasonable evaluation metric once we are done building the model

Next, let’s check the prevalence of heart disease in patients with respect to their age.

# Creating two separate groups for heart disease and no heart disease and then plotting it on overlaid histograms
dis_grp = hd_clean %>% filter(hd_ind == 1)
no_dis_grp = hd_clean %>% filter(hd_ind == 0)

chart_title = "Histogram of age for group with Heart Disease and No Heart Disease"

fig2 = plot_hist_funct(dis_grp$age, no_dis_grp$age, chart_title, "Age", "count")

fig2
  • The graph above shows that with increasing age there is clear indication of increasing heart disease (i.e. we see a higher proportion of people with heart disease in the higher age brackets)

Next, let’s check the prevalence of heart disease across sex (sex). We know the following about the sex field:

  • Value 0 : Female
  • Value 1 : Male
summ_gender = hd_clean %>% mutate(sex = as.factor(sex)) %>% mutate(sex = case_when(
  sex == 0 ~ "Female",
  sex == 1 ~ "Male"
)) %>% 
  group_by(sex,hd_ind) %>% summarise(n = n()) %>% pivot_wider(names_from = hd_ind, values_from = n) 

colnames(summ_gender) = c("sex","No_Disease", "Disease")

fig3 = plot_ly(summ_gender, 
              x = ~summ_gender$sex, 
              y = ~summ_gender$No_Disease, 
              type = 'bar', 
              name = 'No Disease',
              marker = list(color = c("#34eb89"))) %>% 
  add_trace(y = ~summ_gender$Disease, 
            name = 'Disease',
            marker = list(color = c("#eb4034"))) %>% 
  layout(yaxis = list(title = 'Count'),
         title = "Presence of Heart Disease across genders",
         xaxis = list(title = 'Gender'), barmode = 'stack')

fig3
  • We can see from the above chart that a larger share of the male patients has heart disease, which suggests the male population is more susceptible to heart disease

Next variable in order is Chest Pain (cp). The detailed description of the values contained within the column is as follows:

  • Value 1: typical angina
  • Value 2: atypical angina
  • Value 3: non-anginal pain (pain without any relation to angina)
  • Value 4: Asymptomatic

For readers without a background in medicine, the definitions below will make the interpretation easier:

  • Angina: Chest pain caused by reduced blood flow to the heart muscle
  • Typical & Atypical angina: Typical angina usually means chest discomfort. That said, some people experience other symptoms such as nausea or shortness of breath. In such cases the angina is classified as atypical angina.

Now let’s look at the data and see how prevalent heart disease is across patients who have experienced different chest pain types.

summ_cp = hd_clean %>% mutate(cp = as.factor(cp)) %>% mutate(cp = case_when(
  cp == 1 ~ "Typical Angina",
  cp == 2 ~ "Atypical Angina",
  cp == 3 ~ "Non Anginal Pain",
  cp == 4 ~ "Asymptomatic"
)) %>% 
  group_by(cp,hd_ind) %>% summarise(n = n()) %>% pivot_wider(names_from = hd_ind, values_from = n) 

colnames(summ_cp) = c("Chest_Pain","No_Disease", "Disease")

fig4 = plot_ly(summ_cp, 
              x = ~summ_cp$Chest_Pain, 
              y = ~summ_cp$No_Disease, 
              type = 'bar', 
              name = 'No Disease',
              marker = list(color = c("#34eb89"))) %>% 
  add_trace(y = ~summ_cp$Disease, 
            name = 'Disease',
            marker = list(color = c("#eb4034"))) %>% 
  layout(yaxis = list(title = 'Count'),
         title = "Presence of Heart Disease across Anginal Pain types",
         xaxis = list(title = 'Pain Type'), barmode = 'stack')

fig4
  • From the chart above we see that a major chunk of the diseased patients is asymptomatic with respect to anginal pain. Thus there is no clear evidence that experiencing chest pain is an indicator of heart disease.

Next we will look into the Resting Blood Pressure (trestbps) of different patients and its relationship with the presence of heart disease.

chart_title = "Histogram of Resting Blood Pressure (Heart Disease vs. No Heart Disease)"

fig5 = plot_hist_funct(dis_grp$trestbps, 
                       no_dis_grp$trestbps, 
                       chart_title, 
                       "Resting B.P.(mm Hg)",
                       "count")

fig5
  • The graph above tells us that most of the patients in the data set have normal blood pressure. However, the proportion of diseased patients increases in the group that has blood pressure > 150

  • We also see an outlier in the histogram (left most point on the x-axis), we will be taking that data point out from the analysis

Next we will look into the Cholesterol Level (chol) of different patients and its relationship with the presence of heart disease.

chart_title = "Histogram of Cholesterol (Heart Disease vs. No Heart Disease)"
fig6 = plot_hist_funct(dis_grp$chol,
                       no_dis_grp$chol,
                       chart_title,
                       "Cholesterol (mg/dl)",
                       "count")

fig6
  • This data set gives no clear indication that cholesterol is associated with heart disease
  • We do, however, catch a data anomaly
  • The far left of the graph above shows that around 79 patients had a recorded cholesterol of 0 when the measurements were taken
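
One way to quantify an anomaly like this is to count the zero readings and recode them as missing before summarising the column. The sketch below uses made-up readings; the vector chol and the names n_zero and chol_na are hypothetical and only illustrate the idea.

```r
# Hypothetical cholesterol readings (mg/dl); the zeros stand in for the anomaly
chol = c(233, 0, 229, 0, 204, 354)

# Number of implausible zero readings
n_zero = sum(chol == 0)

# Recode zeros as missing so they no longer distort summary statistics
chol_na = replace(chol, chol == 0, NA)

print(n_zero)
print(mean(chol))                   # mean pulled down by the zeros
print(mean(chol_na, na.rm = TRUE))  # mean of the recorded values only
```

A zero is best read as "not recorded" rather than as a real measurement, which is why the comparison of the two means is informative.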

Next variable in order is Fasting Blood Sugar (fbs). This column contains two values (1,0) and can be read as follows:

  • Value 1 : Fasting blood sugar > 120 mg/dl
  • Value 0 : Fasting blood sugar <= 120 mg/dl

Checking the presence of heart disease across blood sugar level:

summ_fbs = hd_clean %>% mutate(fbs = as.factor(fbs)) %>% mutate(fbs = case_when(
  fbs == 1 ~ "> 120 mg/dl",
  fbs == 0 ~ "< 120 mg/dl"
)) %>% 
  group_by(fbs,hd_ind) %>% summarise(n = n()) %>% pivot_wider(names_from = hd_ind, values_from = n) 

colnames(summ_fbs) = c("Blood_Sugar","No_Disease", "Disease")


fig7 = plot_ly(summ_fbs, 
              x = ~summ_fbs$Blood_Sugar, 
              y = ~summ_fbs$No_Disease, 
              type = 'bar', 
              name = 'No Disease',
              marker = list(color = c("#34eb89"))) %>% 
  add_trace(y = ~summ_fbs$Disease, 
            name = 'Disease',
            marker = list(color = c("#eb4034"))) %>% 
  layout(yaxis = list(title = 'Count'),
         title = "Presence of Heart Disease across Blood Sugar Levels",
         xaxis = list(title = 'Blood Sugar Level'), barmode = 'stack')

fig7

This variable does not seem very useful in determining whether a patient has heart disease or not, as the proportion of diseased vs. non-diseased patients is similar across both categories.

Next we will be looking at Resting Electrocardiographic Results(restecg). The levels present in this variable can be interpreted as follows:

  • Value 0 : Normal
  • Value 1 : Having ST - T Wave abnormality
  • Value 2 : Probably having left ventricular hypertrophy

Let’s look at the split of diseased vs non diseased categories across the three segments:

summ_restecg = hd_clean %>% mutate(restecg = as.factor(restecg)) %>% mutate(restecg = case_when(
  restecg == 2 ~ "Ventricular Hypertrophy",
  restecg == 1 ~ "ST - T Wave abnormality",
  restecg == 0 ~ "Normal"
)) %>% 
  group_by(restecg,hd_ind) %>% summarise(n = n()) %>% pivot_wider(names_from = hd_ind, values_from = n) 

colnames(summ_restecg) = c("Rest Ecg","No_Disease", "Disease")

fig8 = plot_ly(summ_restecg, 
              x = ~summ_restecg$`Rest Ecg`, 
              y = ~summ_restecg$No_Disease, 
              type = 'bar', 
              name = 'No Disease',
              marker = list(color = c("#34eb89"))) %>% 
  add_trace(y = ~summ_restecg$Disease, 
            name = 'Disease',
            marker = list(color = c("#eb4034"))) %>% 
  layout(yaxis = list(title = 'Count'),
         title = "Presence of Heart Disease across Rest ECG Categories",
         xaxis = list(title = 'Rest ECG'), barmode = 'stack')

fig8
  • We see that the presence of diseased patients is similar across the three categories
  • This shows that the resting ecg is not a very good indicator of presence of heart disease

Next, let’s look at the Maximum Heart Rate achieved (thalach) by a patient during exercise. Let’s look at the histogram of heart rates across our categories of interest.

chart_title = "Histogram of Maximum Heart Rate during exercise (Disease vs Non Disease)"

fig9 = plot_hist_funct(dis_grp$thalach, 
                       no_dis_grp$thalach, 
                       chart_title,
                       "Maximum Heart Rate During Exercise",
                       "Count")
fig9
  • At first glance the graph above looks counterintuitive, since people with a higher maximum heart rate show a lower presence of heart disease

  • Let’s look at the average age of the participants having thalach > 140 and thalach <= 140 and check whether age could be a factor behind the lower presence of heart disease in people with a higher heart rate

Age distribution in group 1 (thalach > 140):

as_tibble(

  hd_clean %>% 
    select(age, thalach) %>% 
    filter(thalach > 140) %>% 
    summarise(avg_age = mean(age),
              min_age = min(age),
              max_age = max(age),
              patients = n()))
## # A tibble: 1 x 4
##   avg_age min_age max_age patients
##     <dbl>   <dbl>   <dbl>    <int>
## 1    50.3      28      77      361

Age distribution in group 2 (thalach <= 140):

as_tibble(

  hd_clean %>% 
    select(age, thalach) %>% 
    filter(thalach <= 140) %>% 
    summarise(avg_age = mean(age),
              min_age = min(age),
              max_age = max(age),
              patients = n()))
## # A tibble: 1 x 4
##   avg_age min_age max_age patients
##     <dbl>   <dbl>   <dbl>    <int>
## 1    55.8      32      77      379

Now it makes sense. Group 1 is a relatively younger population (average age 50.3 vs. 55.8), which is a plausible reason for its lower presence of heart disease in spite of the higher heart rates.
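
A comparison of average ages is suggestive, but a more direct check is to look at the disease rate within age bands. The sketch below shows the idea on made-up data: the toy data frame, the age cut-off of 55 and the names hr_group, age_band and rates are all hypothetical, and the toy values say nothing about the real data set.

```r
# Hypothetical toy data: age, maximum heart rate and disease flag
toy = data.frame(
  age     = c(35, 42, 48, 44, 58, 63, 66, 61),
  thalach = c(170, 150, 130, 160, 145, 120, 110, 150),
  hd_ind  = c(0, 0, 1, 0, 1, 1, 1, 0)
)

# The two heart-rate groups used above, plus two age bands (assumed cut-off)
toy$hr_group = ifelse(toy$thalach > 140, "thalach > 140", "thalach <= 140")
toy$age_band = ifelse(toy$age < 55, "under 55", "55 and over")

# Disease rate for each combination of age band and heart-rate group
rates = aggregate(hd_ind ~ age_band + hr_group, data = toy, FUN = mean)
print(rates)
```

If the gap between the two heart-rate groups shrinks once patients are compared within the same age band, age explains much of the pattern. If the gap persists, heart rate carries information of its own.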

Next, we will be looking at exang which represents Exercise Induced Angina (Chest pain experienced by a patient during exercise). This field has the following values:

  • Value 1 : yes
  • Value 0 : no

Let’s look at the prevalence of heart disease in patients who experienced exercise induced angina vs. those who did not:

summ_ang = hd_clean %>% mutate(exang = as.factor(exang)) %>% mutate(exang = case_when(
  exang == 1 ~ "Yes",
  exang == 0 ~ "No"
)) %>% 
  group_by(exang,hd_ind) %>% summarise(n = n()) %>% pivot_wider(names_from = hd_ind, values_from = n) 

colnames(summ_ang) = c("Exang","No_Disease", "Disease")

fig10 = plot_ly(summ_ang, 
              x = ~summ_ang$Exang, 
              y = ~summ_ang$No_Disease, 
              type = 'bar', 
              name = 'No Disease',
              marker = list(color = c("#34eb89"))) %>% 
  add_trace(y = ~summ_ang$Disease, 
            name = 'Disease',
            marker = list(color = c("#eb4034"))) %>% 
  layout(yaxis = list(title = 'Count'),
         title = "Presence of Heart Disease across Exang groups",
         xaxis = list(title = 'Exang'), barmode = 'stack')

fig10
  • exang looks like an important predictor in determining whether a patient will suffer from heart disease or not

  • We can clearly see from the chart above that the proportion of patients experiencing heart disease is higher in the category that experienced Exercise Induced Angina

Next variable in line is oldpeak, which represents ST depression induced by exercise relative to rest. Let’s look at the distribution of this variable across the disease categories.

chart_title = "Histogram of oldpeak (Heart Disease vs. No Heart Disease)"

fig11 = plot_hist_funct(dis_grp$oldpeak,
                       no_dis_grp$oldpeak,
                       chart_title,
                       "oldpeak",
                       "count")
fig11

Lastly we will look at the presence of heart disease at different locations.

summ_loc = hd_clean %>% mutate(loc = as.factor(location)) %>% 
  group_by(loc,hd_ind) %>% summarise(n = n()) %>% pivot_wider(names_from = hd_ind, values_from = n) 

colnames(summ_loc) = c("Location","No_Disease", "Disease")

fig12 = plot_ly(summ_loc, 
              x = ~summ_loc$Location, 
              y = ~summ_loc$No_Disease, 
              type = 'bar', 
              name = 'No Disease',
              marker = list(color = c("#34eb89"))) %>% 
  add_trace(y = ~summ_loc$Disease, 
            name = 'Disease',
            marker = list(color = c("#eb4034"))) %>% 
  layout(yaxis = list(title = 'Count'),
         title = "Presence of Heart Disease across Locations",
         xaxis = list(title = 'Location'), barmode = 'stack')

fig12
  • We will not be including location in our model since it does not have any inherent meaning in our analysis context

From the analysis above we have the following key takeaways:

  • Location is not important for our analysis

  • Since the cholesterol column has a lot of zeroes we will drop it from our analysis. Our analysis also revealed that cholesterol is not a strong predictor

Modeling

For the modeling purpose we will be using the cleaned up version of the data set. We will be doing some additional manipulation based on our analysis above followed by an estimation, validation and test split of our hd_clean data set.

# Dropping location and cholesterol columns

hd_clean = hd_clean %>% select(-location, -chol)

# Converting columns to factors
hd_clean = hd_clean %>% mutate(hd_ind = as.factor(hd_ind),
                               sex = as.factor(sex),
                               cp = as.factor(cp),
                               fbs = as.factor(fbs),
                               restecg = as.factor(restecg),
                               exang = as.factor(exang))
set.seed(42)
# doing a test train split
hd_trn_idx = sample(nrow(hd_clean),size = 0.8*nrow(hd_clean))
hd_trn = hd_clean[hd_trn_idx,]
hd_tst = hd_clean[-hd_trn_idx,]

set.seed(42)
# Doing an estimation validation split
hd_est_idx = sample(nrow(hd_trn),size = 0.8*nrow(hd_trn))
hd_est = hd_trn[hd_est_idx,]
hd_val = hd_trn[-hd_est_idx,]


print(paste("Rows (Estimation Data):", nrow(hd_est), "Columns (Estimation Data):",ncol(hd_est)))
## [1] "Rows (Estimation Data): 473 Columns (Estimation Data): 10"
print(paste("Rows (Validation Data):", nrow(hd_val), "Columns (Validation Data):",ncol(hd_val)))
## [1] "Rows (Validation Data): 119 Columns (Validation Data): 10"
print(paste("Rows (Test Data):", nrow(hd_tst), "Columns (Test Data):",ncol(hd_tst)))
## [1] "Rows (Test Data): 148 Columns (Test Data): 10"

Before we begin fitting different models to our estimation data set let’s establish a baseline against which our model will be compared.

We know that there is a healthy class balance in our data set, so we will be using Accuracy as the benchmarking metric for our analysis.

Benchmarking

# Share of the most frequent class in the estimation data
acc = round(max(prop.table(table(hd_est$hd_ind))),4)*100
print(paste("If we just predict the majority class our accuracy would be: ", 
            acc,"%"
            ))
## [1] "If we just predict the majority class our accuracy would be:  52.85 %"

There is our benchmark. Any model that we build from here on needs to have an accuracy better than 52.85%.

We will be fitting the following models to our estimation data:

  • Logistic Regression
  • Logistic Regression with Ridge Penalty
  • Random Forest
  • Naive Bayes Classifier
  • Ensemble (of the above 4 algorithms), Majority Voting

Let’s start by fitting a Logistic Regression model on our data. After fitting the model we get the following model summary.

log_mod = glm(hd_ind ~., data = hd_est, family = "binomial")
summary(log_mod)
## 
## Call:
## glm(formula = hd_ind ~ ., family = "binomial", data = hd_est)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.8555  -0.6381   0.2496   0.6466   2.3106  
## 
## Coefficients:
##              Estimate Std. Error z value Pr(>|z|)    
## (Intercept) -2.013580   1.668590  -1.207  0.22753    
## age          0.018148   0.015342   1.183  0.23684    
## sex1         1.334417   0.291395   4.579 4.66e-06 ***
## cp2         -0.632939   0.622810  -1.016  0.30950    
## cp3          0.155136   0.581698   0.267  0.78970    
## cp4          1.455751   0.569216   2.557  0.01054 *  
## trestbps     0.003722   0.006889   0.540  0.58895    
## fbs1         0.851555   0.381680   2.231  0.02568 *  
## restecg1     0.179783   0.372912   0.482  0.62973    
## restecg2    -0.068652   0.308004  -0.223  0.82362    
## thalach     -0.013431   0.005744  -2.338  0.01937 *  
## exang1       0.777984   0.284755   2.732  0.00629 ** 
## oldpeak      0.600683   0.142529   4.214 2.50e-05 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 654.18  on 472  degrees of freedom
## Residual deviance: 408.52  on 460  degrees of freedom
## AIC: 434.52
## 
## Number of Fisher Scoring iterations: 5

We see from the model summary that not all the variables are significant. We remove trestbps, fbs and restecg and rebuild the model with age, sex, cp, thalach, exang and oldpeak.

# Fitting logistic model with the reduced set of variables
log_mod = glm(hd_ind ~ 
                age +
                sex + 
                cp + 
                thalach + 
                exang + 
                oldpeak, 
              data = hd_est,
              family = "binomial")

summary(log_mod)
## 
## Call:
## glm(formula = hd_ind ~ age + sex + cp + thalach + exang + oldpeak, 
##     family = "binomial", data = hd_est)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6352  -0.6379   0.2545   0.6702   2.4524  
## 
## Coefficients:
##              Estimate Std. Error z value Pr(>|z|)    
## (Intercept) -1.621369   1.438073  -1.127  0.25955    
## age          0.024496   0.014463   1.694  0.09031 .  
## sex1         1.359903   0.288121   4.720 2.36e-06 ***
## cp2         -0.552475   0.620372  -0.891  0.37317    
## cp3          0.144574   0.583173   0.248  0.80420    
## cp4          1.461028   0.571252   2.558  0.01054 *  
## thalach     -0.014496   0.005596  -2.590  0.00958 ** 
## exang1       0.823938   0.278378   2.960  0.00308 ** 
## oldpeak      0.580761   0.140232   4.141 3.45e-05 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 654.18  on 472  degrees of freedom
## Residual deviance: 414.86  on 464  degrees of freedom
## AIC: 432.86
## 
## Number of Fisher Scoring iterations: 5

Now let’s check the sensitivity, specificity and accuracy plot for different probability thresholds for the logistic model.

prob_thresholds = seq(0,1,0.01)
mod_out_log = sapply(prob_thresholds, Model_Diagnostics, model_object = log_mod, val_dat = hd_val)
log_plot = spec_sense_plot(mod_out_log)
log_plot

We see that the three curves intersect at a threshold of about 0.52. The code below classifies an entry as 1 when its predicted probability exceeds 0.5, which is close to this crossing point, and as 0 otherwise.
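
The helper behind this plot is not shown in this section, so here is a minimal sketch of how sensitivity, specificity and accuracy can be computed over a grid of thresholds. The function threshold_metrics and the toy vectors prob and actual are hypothetical stand-ins, not the Model_Diagnostics function used above.

```r
# Hypothetical helper: sensitivity, specificity and accuracy at one threshold.
# `prob` holds predicted probabilities of class 1, `actual` the true 0/1 labels.
threshold_metrics = function(threshold, prob, actual) {
  pred = ifelse(prob > threshold, 1, 0)
  tp = sum(pred == 1 & actual == 1)
  tn = sum(pred == 0 & actual == 0)
  fp = sum(pred == 1 & actual == 0)
  fn = sum(pred == 0 & actual == 1)
  c(threshold   = threshold,
    sensitivity = tp / (tp + fn),
    specificity = tn / (tn + fp),
    accuracy    = (tp + tn) / length(actual))
}

# Toy probabilities and labels (made up for illustration)
prob   = c(0.92, 0.81, 0.66, 0.58, 0.47, 0.35, 0.22, 0.10)
actual = c(1,    1,    0,    1,    1,    0,    0,    0)

# Evaluate a grid of thresholds; each column of the result is one threshold
grid = seq(0.1, 0.9, by = 0.1)
metrics = sapply(grid, threshold_metrics, prob = prob, actual = actual)

# Threshold where sensitivity and specificity are closest to each other
gap = abs(metrics["sensitivity", ] - metrics["specificity", ])
print(grid[which.min(gap)])
```

Picking the crossing point balances the two error types. When one kind of error is costlier, for example missing a diseased patient, a lower threshold that favours sensitivity may be preferable.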

Printing confusion Matrix and accuracy for the logistic model

log_mod_pred = ifelse(predict(log_mod, hd_val, type = "response") > 0.5, 1, 0)
print(confusion_matrix(log_mod_pred))
##          Predicted 1 Predicted 0
## Actual 1          55          13
## Actual 0           8          43
log_mod_acc = round(mean(log_mod_pred == hd_val$hd_ind),3)
print(paste("Validation Accuracy: ",log_mod_acc))
## [1] "Validation Accuracy:  0.824"

Next we will fit a Logistic Model with Ridge Penalty on our data. After fitting the model we get the following model summary.

# Fitting a logistic model with Ridge Penalty
x_mat = data.matrix(hd_est %>% select(-hd_ind))
y = hd_est$hd_ind
set.seed(42)
log_mod_ridge = cv.glmnet(x_mat, y, alpha = 0,family = "binomial")
coefficients(log_mod_ridge)
## 10 x 1 sparse Matrix of class "dgCMatrix"
##                        1
## (Intercept) -4.183194501
## age          0.015858203
## sex          0.713159294
## cp           0.469361263
## trestbps     0.003503962
## fbs          0.413613677
## restecg      0.029119256
## thalach     -0.010639747
## exang        0.665922993
## oldpeak      0.334197466
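
Note that the coefficient table has a single row each for cp and restecg. This is because data.matrix replaces every factor by its integer codes, so a multi-level factor enters the model as one numeric column. The sketch below contrasts this with dummy encoding through model.matrix from base R. The toy data frame and the names x_codes and x_dummy are hypothetical, and refitting with dummy columns would give results different from those reported here.

```r
# Toy predictors (hypothetical values); cp is a factor with four levels
toy = data.frame(
  age = c(63, 67, 37, 41),
  cp  = factor(c(1, 4, 3, 2), levels = 1:4)
)

# data.matrix() replaces each factor by its integer codes: one column for cp
x_codes = data.matrix(toy)

# model.matrix() expands cp into dummy columns; [, -1] drops the intercept
x_dummy = model.matrix(~ ., data = toy)[, -1]

print(colnames(x_codes))
print(colnames(x_dummy))
```

Integer codes impose an ordering and equal spacing on the chest pain types, while dummy columns let each level have its own coefficient, as in the plain logistic model above.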

Let’s find out the proper probability threshold for the Logistic Model (with Ridge Penalty) by plotting a specificity, sensitivity and accuracy chart.

mod_out_log_ridge = sapply(prob_thresholds, Model_Diagnostics, model_object = log_mod_ridge, val_dat = hd_val)
log_ridge_plot = spec_sense_plot(mod_out_log_ridge)
log_ridge_plot

Based on the chart above we observe that accuracy peaks around a threshold of 0.54. Let’s print the confusion matrix and accuracy of the model below.

log_ridge_pred = ifelse(predict(log_mod_ridge, 
                                          data.matrix(hd_val %>% select(-hd_ind)), 
                                          type = "response", 
                                          s = "lambda.1se") > 0.54, 1,0)
print(confusion_matrix(log_ridge_pred))
##          Predicted 1 Predicted 0
## Actual 1          53          10
## Actual 0          10          46
log_ridge_acc = mean(log_ridge_pred == hd_val$hd_ind)
print(paste("Validation Accuracy: ",round(log_ridge_acc,3)))
## [1] "Validation Accuracy:  0.832"

Next we will fit a Naive Bayes model on our data. After fitting the model we get the following model summary.

Naive_Bayes_Mod = klaR::NaiveBayes(hd_ind ~., data = hd_est)
summary(Naive_Bayes_Mod)
##           Length Class      Mode     
## apriori   2      table      numeric  
## tables    9      -none-     list     
## levels    2      -none-     character
## call      3      -none-     call     
## x         9      data.frame list     
## usekernel 1      -none-     logical  
## varnames  9      -none-     character

Let’s find out the proper probability threshold for the Naive Bayes Model by plotting a specificity, sensitivity and accuracy chart

mod_out = sapply(prob_thresholds, Model_Diagnostics, model_object = Naive_Bayes_Mod, val_dat = hd_val)
nb_plot = spec_sense_plot(mod_out)
nb_plot

Based on the chart above we observe that accuracy peaks around a threshold of 0.57. Let’s print the confusion matrix and accuracy of the model below.

nb_pred = ifelse(predict(Naive_Bayes_Mod, hd_val)$posterior[,'1'] > 0.57,1,0)
print(confusion_matrix(nb_pred))
##          Predicted 1 Predicted 0
## Actual 1          51           8
## Actual 0          12          48
nb_acc = round(mean(nb_pred == hd_val$hd_ind),3)
print(paste("Validation Accuracy: ",nb_acc))
## [1] "Validation Accuracy:  0.832"

Next we will fit a Random Forest model on our data. After fitting the model, let’s plot the error rates to identify where the OOB error rate stabilizes.

set.seed(42)
# Fit a random forest model with ntree = 1000 and plot the error rates
rf_mod = randomForest(hd_ind ~., data = hd_est, mtry = 3, ntree = 1000)

# Plotting the error
ntree = seq_len(1000)
rf_error_df = as.data.frame(rf_mod$err.rate)
fig = plot_ly(rf_error_df, 
              x = ~ntree, 
              y = ~OOB, name = 'OOB Error', type = 'scatter', mode = 'lines') %>% 
  add_trace(y = ~rf_error_df$`0`, name = '0 Misclassification', mode = 'lines') %>% 
  add_trace(y = ~rf_error_df$`1`, name = '1 Misclassification', mode = 'lines') %>% 
  layout(xaxis = list(title = 'Number of trees'),
         yaxis = list(title = 'Error Rate'))

fig

We see that the error stabilizes around ntree = 700. Now let’s plot the error rate against varying values of the mtry parameter.

# Candidate mtry values: 1 up to the number of predictors
mtry_values = seq_len(ncol(hd_est) - 1)

# Fit a forest for a given mtry and return its final OOB error rate
fit_rf = function(df, mtry){
  fit_random = randomForest(hd_ind ~., data = df, mtry = mtry, ntree = 700)
  fit_random$err.rate[nrow(fit_random$err.rate), "OOB"]
}

error_rate = sapply(mtry_values, fit_rf, df = hd_est)
error_df = data.frame(error_rate)
error_df$mtry = mtry_values

fig = plot_ly(
  error_df,
  x = ~mtry,
  y = ~error_rate,
  name = "OOB Error",
  type = "scatter", mode = 'lines'
) %>% layout(
  xaxis = list(title = 'Number of Randomly Selected Variables'),
  yaxis = list(title = 'Error Rate')
)

fig

Based on the chart above we see that the minimum OOB error rate is at mtry = 1. Now let’s train the Random Forest model with mtry = 1 and ntree = 700.

# Fitting the final random forest model
set.seed(42)
# mtry = 1 and ntree = 700, as chosen from the error plots above
rf_mod = randomForest(hd_ind ~., data = hd_est, mtry = 1, ntree = 700)

rf_mod$importance
##          MeanDecreaseGini
## age             15.209609
## sex              8.984770
## cp              25.398502
## trestbps        11.380133
## fbs              3.609824
## restecg          5.020334
## thalach         19.449458
## exang           15.595407
## oldpeak         19.610461

Now let’s find out the probability threshold of the Random Forest Model for classification.

mod_out_rf = sapply(prob_thresholds, Model_Diagnostics, model_object = rf_mod, val_dat = hd_val)
rf_plot = spec_sense_plot(mod_out_rf)
rf_plot

The probability threshold from the above chart for the Random Forest model is 0.54.

rf_pred = ifelse(predict(rf_mod, hd_val, type = "prob")[,'1'] > 0.54, 1,0)
print(confusion_matrix(rf_pred))
##          Predicted 1 Predicted 0
## Actual 1          53          11
## Actual 0          10          45
rf_acc = mean(rf_pred == hd_val$hd_ind)
print(paste("Validation Accuracy: ",round(rf_acc,3)))
## [1] "Validation Accuracy:  0.824"

After building the individual models we combine them into an ensemble. The ensemble rule is as follows.

  • We classify a record as 0 only if all four models classify it as 0; otherwise we classify it as 1

Based on this rule we will calculate the accuracy and the confusion matrix again.

ensemble_pred = cbind(log = log_mod_pred,
                      log_ridge = log_ridge_pred,
                      nb_pred = nb_pred,
                      rf_pred = rf_pred)

colnames(ensemble_pred) = c("log", "log_ridge", "nb_pred","rf_pred")

fin_pred = apply(ensemble_pred, 1, function(x){sum(x)})
ensemble_acc = mean(ifelse(fin_pred >=1,1,0) == hd_val$hd_ind)

print(confusion_matrix(ifelse(fin_pred >=1,1,0)))
##          Predicted 1 Predicted 0
## Actual 1          55          15
## Actual 0           8          41
print(paste("Validation Accuracy: ",round(ensemble_acc,3)))
## [1] "Validation Accuracy:  0.807"
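
To make the ensemble rule concrete, here is a small sketch on made-up predictions from four hypothetical models. The majority vote is shown only for comparison and is not the rule used in this analysis.

```r
# Toy 0/1 predictions from four hypothetical models (rows = patients)
toy_votes = cbind(log       = c(0, 1, 0, 1, 0),
                  log_ridge = c(0, 1, 0, 1, 0),
                  nb        = c(0, 0, 1, 1, 0),
                  rf        = c(0, 0, 0, 1, 1))

# Number of models voting 1 for each patient
n_positive = rowSums(toy_votes)

# Rule used in this analysis: 1 if at least one model predicts 1
any_positive = ifelse(n_positive >= 1, 1, 0)

# Alternative for comparison: 1 if more than half of the models predict 1
majority = ifelse(n_positive > ncol(toy_votes) / 2, 1, 0)

print(cbind(toy_votes, any_positive, majority))
```

Because a single positive vote is enough, the rule used here can never predict fewer 1s than any individual model. That tends to favour catching diseased patients at the cost of more false alarms.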

Plotting the accuracies in the chart below, for model performance comparison:

fig <- plot_ly(
  x = c("Logistic", "Logistic with Ridge Penalty", "Naive Bayes", "Random Forest", "Ensemble"),
  y = c(log_mod_acc, log_ridge_acc, nb_acc, rf_acc, ensemble_acc),
  type = "bar"
) %>% layout(xaxis = list(title = 'Model'),
             yaxis = list(title = 'Accuracy'),
             title = "Model Accuracy(Validation)")

fig
  • From the chart above we see that all the algorithms give broadly similar performance; the ensemble (0.807) does not improve on the Naive Bayes (0.832) or Random Forest (0.824) validation accuracy

  • All the models perform considerably better than the benchmark (the probability of the majority class)
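
For reference, the majority-class benchmark can be computed as in the sketch below. The helper name majority_baseline and the toy labels are ours, for illustration only; on the validation data one would call it as majority_baseline(hd_val$hd_ind).

```r
# Hypothetical helper: accuracy obtained by always predicting
# the most common class in `actual`.
majority_baseline = function(actual) {
  max(table(actual)) / length(actual)
}

# Made-up labels: 5 ones and 3 zeros
toy_actual = c(1, 1, 1, 0, 0, 1, 0, 1)
majority_baseline(toy_actual)
```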


Results

Let’s check the prediction on the test data set and see the accuracy of each of the models.

# Making predictions on the test dataset
log_tst_pred = ifelse(predict(log_mod, hd_tst, type = "response") > 0.5, 1,0)
log_tst_acc = mean(log_tst_pred == hd_tst$hd_ind)

log_ridge_tst_pred = ifelse(predict(log_mod_ridge, data.matrix(hd_tst %>% select(-hd_ind)), s = "lambda.1se", type = "response") > 0.54, 1,0)
log_ridge_tst_acc = mean(log_ridge_tst_pred == hd_tst$hd_ind)

rf_tst_pred = ifelse(predict(rf_mod, hd_tst, type = "prob")[,'1'] > 0.54, 1,0)
rf_tst_acc = mean(rf_tst_pred == hd_tst$hd_ind)

nb_tst_pred = ifelse(predict(Naive_Bayes_Mod, hd_tst)$posterior[,'1'] > 0.57,1,0)
nb_tst_acc =  mean(nb_tst_pred == hd_tst$hd_ind)

ensemble_pred_tst = cbind(log = log_tst_pred,
                          log_ridge = log_ridge_tst_pred,
                          nb_pred = nb_tst_pred,
                          rf_pred = rf_tst_pred)

colnames(ensemble_pred_tst) = c("log", "log_ridge", "nb_pred","rf_pred")

fin_pred_tst = ifelse(apply(ensemble_pred_tst, 1, function(x){sum(x)}) >=1,1,0)
ensemble_tst_acc = mean(fin_pred_tst == hd_tst$hd_ind)


fig <- plot_ly(
  x = c("Logistic", "Logistic with Ridge Penalty", "Naive Bayes", "Random Forest", "Ensemble"),
  y = c(log_tst_acc, log_ridge_tst_acc, nb_tst_acc, rf_tst_acc, ensemble_tst_acc),
  type = "bar"
) %>% layout(xaxis = list(title = 'Model'),
             yaxis = list(title = 'Accuracy'),
             title = "Model Accuracy (Test)")

fig

Discussion

Since we have several models that predict whether a patient has heart disease with roughly 80% accuracy, we can use their predictions to assist doctors and medical practitioners in making a quicker diagnosis. We can give the predictions from each of these models to the doctor along with the patient report. If a patient is classified as diseased, there is a good chance that the patient has heart disease, at which point the doctor can look at the patient's records in depth.


Appendix

The data dictionary for this analysis can be found below:

Variable   Description
age        Age of the patient
sex        Patient gender
cp         Chest pain type
trestbps   Resting blood pressure
chol       Cholesterol level
fbs        Fasting blood sugar level
restecg    Resting electrocardiographic results
thalach    Maximum heart rate achieved during exercise
exang      Exercise-induced angina
oldpeak    ST depression induced by exercise relative to rest
location   Location from which the patient sample was collected